# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Init for base architecture engine monitor register. """

import time
import os
import stat
from typing import Optional, Union, Iterable
import numpy as np

import luojianet_ms as ms
from luojianet_ms import save_checkpoint
from luojianet_ms.train.callback import Callback

#from checkparam import Rel, Validator as validator
from check_param import Rel,Validator as validator
__all__ = ["LossMonitor", "ValAccMonitor"]


class LossMonitor(Callback):
    """
    Callback that reports the training loss of a classification network.

    Args:
        lr_init (Union[float, Iterable], optional): Learning rate to display, either a single
            value or a per-step schedule. Default: None.
        per_print_times (int): Print one log line every this many steps. Default: 1.

    Examples:
    from mindvision.engine.callback import LossMonitor
    lr = [0.01, 0.008, 0.006, 0.005, 0.002]
    monitor = LossMonitor(lr_init=lr, per_print_times=100)
    """

    def __init__(self,
                 lr_init: Optional[Union[float, Iterable]] = None,
                 per_print_times: int = 1):
        super().__init__()
        # Step number of the most recently printed log line; starts at 0.
        self.last_print_time = 0
        self.per_print_times = per_print_times
        self.lr_init = lr_init

    # pylint: disable=unused-argument





def epoch_begin(self, run_context):
    """
    Reset the per-epoch loss history and start the epoch timer.

    Args:
        run_context (RunContext): Context of the process running (not used here).
    """
    self.losses, self.epoch_time = [], time.time()




def epoch_end(self, run_context):
    """
    Print the epoch duration, the average time per step and the mean loss of the epoch.

    Args:
        run_context (RunContext): Context of the process running; ``batch_num``
            from its original args is used as the number of steps per epoch.
    """
    params = run_context.original_args()
    epoch_mseconds = 1000 * (time.time() - self.epoch_time)
    per_step_mseconds = epoch_mseconds / params.batch_num
    line = (f"Epoch time: {epoch_mseconds:5.3f} ms, "
            f"per step time: {per_step_mseconds:5.3f} ms, "
            f"avg loss: {np.mean(self.losses):5.3f}")
    print(line, flush=True)


# pylint: disable=unused-argument


def step_begin(self, run_context):
    """
    Start the timer for the current training step.

    Args:
        run_context (RunContext): Context of the process running (not used here).
    """
    started_at = time.time()
    self.step_time = started_at




def step_end(self, run_context):
    """
    Print training info at the end of step.

    Reduces the network output to a scalar loss and appends it to ``self.losses``.
    Training stops if that loss is NaN or infinite. Otherwise a progress line is
    printed once every ``self.per_print_times`` steps.

    Args:
        run_context (RunContext): Context of the process running.

    Raises:
        ValueError: If the scalar loss is NaN or infinite.
    """
    callback_params = run_context.original_args()
    step_mseconds = (time.time() - self.step_time) * 1000
    loss = callback_params.net_outputs

    # Multi-output networks: take the first element as the loss.
    if isinstance(loss, (tuple, list)):
        if isinstance(loss[0], ms.Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
            loss = loss[0]

    if isinstance(loss, ms.Tensor) and isinstance(loss.asnumpy(), np.ndarray):
        loss = np.mean(loss.asnumpy())

    self.losses.append(loss)
    cur_step_in_epoch = (callback_params.cur_step_num - 1) % callback_params.batch_num + 1

    # Boundary check. np.mean of a float32 array returns np.float32, which is not a
    # subclass of Python float, so numpy floating scalars must be accepted as well.
    if isinstance(loss, (float, np.floating)) and (np.isnan(loss) or np.isinf(loss)):
        raise ValueError("Invalid loss, terminate training.")

    if (callback_params.cur_step_num - self.last_print_time) < self.per_print_times:
        return
    self.last_print_time = callback_params.cur_step_num

    # A sequence lr_init is treated as a per-step schedule indexed by the global step.
    if isinstance(self.lr_init, (list, tuple, np.ndarray)):
        lr_output = self.lr_init[callback_params.cur_step_num - 1]
    else:
        lr_output = self.lr_init

    message = (f"Epoch:[{(callback_params.cur_epoch_num - 1):3d}/{callback_params.epoch_num:3d}], "
               f"step:[{cur_step_in_epoch:5d}/{callback_params.batch_num:5d}], "
               f"loss:[{loss:5.3f}/{np.mean(self.losses):5.3f}], "
               f"time:{step_mseconds:5.3f} ms")
    # lr_init defaults to None; formatting None with ':5.5f' raises TypeError, so omit it.
    if lr_output is not None:
        message += f", lr:{lr_output:5.5f}"
    print(message, flush=True)





class ValAccMonitor(Callback):
    """
    Monitors the train loss and the validation accuracy, after each epoch saves the
    best checkpoint file with highest validation accuracy.

    Args:
        model (ms.Model): The model to monitor.
        dataset_val (ms.dataset): The dataset that the model needs.
        num_epochs (int): The number of epochs.
        interval (int): Every how many epochs to validate and print information. Must be
            at least 1. Default: 1.
        eval_start_epoch (int): The first epoch at which validation runs. Default: 1.
        save_best_ckpt (bool): Whether to save the checkpoint file which performs best. Default: True.
        ckpt_directory (str): The path to save checkpoint files; it is created if missing.
            An empty string means the current working directory. Default: './'.
        best_ckpt_name (str): The file name of the checkpoint file which performs best. Default: 'best.ckpt'.
        metric_name (str): The name of metric for model evaluation. Default: 'Accuracy'.
        dataset_sink_mode (bool): Whether to use the dataset sinking mode. Default: True.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
    import luojianet_ms as ms
    import luojianet_ms.nn as nn
    import luojianet_ms.dataset as ds
    from mindvision.classification.models import lenet
    from mindvision.classification.dataset import Mnist

    net = lenet()
    opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.9)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
    model = ms.Model(net, loss,opt,metrics={"Accuracy":nn.Accuracy()})
    dataset_val = Mnist("./mnist", split="test", batch_size=32, resize=32, download=True)
    dataset_val = dataset_val.run()
    monitor = ValAccMonitor(model, dataset_val, num_epochs=10)
    """

    def __init__(self,
                 model: ms.Model,
                 dataset_val: ms.dataset,
                 num_epochs: int,
                 interval: int = 1,
                 eval_start_epoch: int = 1,
                 save_best_ckpt: bool = True,
                 ckpt_directory: str = "./",
                 best_ckpt_name: str = "best.ckpt",
                 metric_name: str = "Accuracy",
                 dataset_sink_mode: bool = True):
        super(ValAccMonitor, self).__init__()
        self.model = model
        self.dataset_val = dataset_val
        self.num_epochs = num_epochs
        self.eval_start_epoch = eval_start_epoch
        self.save_best_ckpt = save_best_ckpt
        self.metric_name = metric_name
        self.interval = validator.check_int(interval, 1, Rel.GE, "interval")
        # Best metric value seen so far; a result >= this replaces the best checkpoint.
        self.best_res = 0
        self.dataset_sink_mode = dataset_sink_mode

        # exist_ok=True removes the check-then-create race of isdir()+makedirs().
        # An empty path means the current directory; os.makedirs("") would raise.
        if ckpt_directory:
            os.makedirs(ckpt_directory, exist_ok=True)
        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)





def apply_eval(self):
    """Run ``self.model.eval`` on the validation set and return the ``self.metric_name`` entry."""
    metrics = self.model.eval(self.dataset_val, dataset_sink_mode=self.dataset_sink_mode)
    return metrics[self.metric_name]





def epoch_end(self, run_context):
    """
    After epoch, print train loss and val accuracy,
    save the best ckpt file with highest validation accuracy.

    Validation runs at epoch ``self.eval_start_epoch`` and then every
    ``self.interval`` epochs. When the result is at least ``self.best_res``, the
    best value is updated. If ``self.save_best_ckpt`` is set, the file at
    ``self.best_ckpt_path`` is also replaced with a fresh checkpoint.

    Args:
        run_context (RunContext): Context of the process running.
    """
    callback_params = run_context.original_args()
    cur_epoch = callback_params.cur_epoch_num

    if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
        return

    # Validation result
    res = self.apply_eval()

    # Multi-output networks return a tuple/list whose first element is the loss.
    # np.mean reduces the loss to a scalar so that ':5.3f' formatting also works for
    # loss arrays that are not 0-d.
    net_outputs = callback_params.net_outputs
    if isinstance(net_outputs, (tuple, list)):
        net_outputs = net_outputs[0]
    train_loss = np.mean(net_outputs.asnumpy())

    print("-" * 20)
    print(f"Epoch: [{cur_epoch: 3d} / {self.num_epochs: 3d}], "
          f"Train Loss: [{train_loss :5.3f}], "
          f"{self.metric_name}: {res: 5.3f}")

    # Save the best ckpt file
    if res >= self.best_res:
        self.best_res = res
        if self.save_best_ckpt:
            if os.path.exists(self.best_ckpt_path):
                # Make the old file writable (e.g. read-only on Windows) before deleting it.
                os.chmod(self.best_ckpt_path, stat.S_IWRITE)
                os.remove(self.best_ckpt_path)
            save_checkpoint(callback_params.train_network, self.best_ckpt_path)


# pylint: disable=unused-argument



def end(self, run_context):
    """
    Report the best validation metric value and the checkpoint path once training finishes.

    Args:
        run_context (RunContext): Context of the process running (not used here).
    """
    separator = "=" * 80
    summary = (f"End of validation the best {self.metric_name} is: {self.best_res: 5.3f}, "
               f"save the best ckpt file in {self.best_ckpt_path}")
    print(separator)
    print(summary, flush=True)
